{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## CUDA_VISIBLE_DEVICES is set before torch is imported\n",
    "import os, sys\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "sys.path.append('.')\n",
    "import torch\n",
    "## TF32 is allowed for matmul and cuDNN; cuDNN determinism is switched off\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "torch.backends.cudnn.deterministic = False\n",
    "from diffuser.guides.policies import Policy\n",
    "import diffuser.datasets as datasets\n",
    "import diffuser.utils as utils\n",
    "from diffuser.models import GaussianDiffusionPB\n",
    "from datetime import datetime\n",
    "import os.path as osp\n",
    "from tqdm import tqdm\n",
    "from diffuser.utils.jupyter_utils import suppress_stdout\n",
    "import numpy as np\n",
    "np.set_printoptions(precision=3)\n",
    "\n",
    "## NOTE: 'maze2d-concave-base' has to be downloaded from onedrive before launching\n",
    "class Parser(utils.Parser):\n",
    "    config: str = \"config/rm2d/rSmazeC43_concave_nw7_exp.py\"\n",
    "\n",
    "#---------------------------------- setup ----------------------------------#\n",
    "\n",
    "## the 'diffusion' (training) args are parsed first, the 'plan' args second\n",
    "args_train = Parser().parse_args('diffusion', from_jupyter=True)\n",
    "args = Parser().parse_args('plan', from_jupyter=True)\n",
    "\n",
    "## savepath is cleared here (previously considered: osp.join(args.savepath, sub_dir))\n",
    "args.savepath = None\n",
    "args.load_unseen_maze = True\n",
    "\n",
    "## args.dataset is a string: the name of the env\n",
    "print('args.dataset', type(args.dataset), args.dataset)\n",
    "print('args.dataset_eval', type(args.dataset_eval), args.dataset_eval)\n",
    "use_normed_wallLoc = args_train.dataset_config.get('use_normed_wallLoc', False)\n",
    "\n",
    "## True -> eval (unseen) env, False -> trained (seen) env\n",
    "load_unseen_maze = args.load_unseen_maze\n",
    "\n",
    "with suppress_stdout():\n",
    "    ## ---------- env list and dataset normalizer ------------\n",
    "    plan_env_list = datasets.load_environment(args.dataset, is_eval=load_unseen_maze)\n",
    "    train_normalizer = utils.load_datasetNormalizer(\n",
    "        plan_env_list.dataset_url, args_train, plan_env_list\n",
    "    )\n",
    "\n",
    "    ## ---------- trained diffusion model, 'latest' epoch ------------\n",
    "    ld_config = {'env_instance': plan_env_list}\n",
    "    diffusion_experiment = utils.load_potential_diffusion_model(\n",
    "        args.logbase,\n",
    "        args.dataset,\n",
    "        args_train.exp_name,\n",
    "        epoch='latest',\n",
    "        ld_config=ld_config,\n",
    "    )\n",
    "\n",
    "    ## planning below uses the `.ema` model of the loaded experiment\n",
    "    diffusion = diffusion_experiment.ema"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### create an env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''pick an env idx'''\n",
    "from diffuser.utils.jupyter_utils import create_every_wloc_43\n",
    "\n",
    "## 'from_env'  : reuse one of the envs already held by plan_env_list\n",
    "## 'user_input': build a customized env from hand-written obstacle rows\n",
    "# wloc_source = 'from_env'\n",
    "wloc_source = 'user_input'\n",
    "\n",
    "if wloc_source == 'from_env':\n",
    "    env_id = 10\n",
    "    wallLoc_list = plan_env_list.recs_grp.save_center_and_mode(is_save=False)\n",
    "    # wloc = plan_env_list.wallLoc_list[env_id] ### wall locations of 2d square blocks\n",
    "    wloc_np = wallLoc_list[env_id] ## concave obstacle locations as model input\n",
    "    env = plan_env_list.create_single_env(env_id)\n",
    "elif wloc_source == 'user_input':\n",
    "    env_id = 5 # placeholder: this slot of plan_env_list.wallLoc_list is overwritten below\n",
    "    ## candidate layouts, one row per obstacle: [x, y, mode_idx]\n",
    "    user_layouts = {\n",
    "        'layout_a': np.array([\n",
    "            [1.03, 2.82, 4.],\n",
    "            [1.77, 4.50, 1.],\n",
    "            [0.88, 0.9, 1.],\n",
    "            [3.09, 3.52, 4.],\n",
    "            [3.02, 1.42, 2.],\n",
    "            [4.30, 2.82, 1.],\n",
    "            [2.35, 2.58, 1.],\n",
    "        ]),\n",
    "        'layout_b': np.array([\n",
    "            [1.036, 0.872, 4.],\n",
    "            [1.46 , 3.797, 4.],\n",
    "            [2.366, 1.987, 2.],\n",
    "            [2.853, 4.151, 3.],\n",
    "            [3.213, 0.676, 3.],\n",
    "            [3.6  , 2.92 , 3.],\n",
    "            [4.277, 1.938, 1.],\n",
    "        ]),\n",
    "    }\n",
    "    ## 'layout_b' is the layout this cell effectively used before: it used to\n",
    "    ## overwrite 'layout_a' right after assigning it. Switch the name to try the other one.\n",
    "    layout_name = 'layout_b'\n",
    "    wloc_np = user_layouts[layout_name]\n",
    "\n",
    "    ## membership test instead of a range test, so a value such as 2.5 is rejected too\n",
    "    modes_ok = np.isin(wloc_np[:, 2], (1, 2, 3, 4))\n",
    "    assert modes_ok.all(), 'should be in {1,2,3,4}'\n",
    "    env_hExts = plan_env_list.model_list[0].wall_hExts #  np.array([0.25, 0.25])\n",
    "    wloc_np_xy = create_every_wloc_43(wloc_np, env_hExts[0], ) ## (21, 2)\n",
    "    # print(env_hExts)\n",
    "    plan_env_list.wallLoc_list[env_id] = wloc_np_xy\n",
    "    ## create the customized env\n",
    "    env = plan_env_list.create_env_by_pos(env_id, wloc_np_xy, env_hExts)\n",
    "else:\n",
    "    ## fail here rather than with a NameError on `env` further down\n",
    "    raise ValueError(f'unknown wloc_source: {wloc_source!r}')\n",
    "\n",
    "print(wloc_np[:2]) ## [ [x, y, mode_idx], ..., ]\n",
    "env.render()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## display the full obstacle array; rows are [x, y, mode_idx]\n",
    "wloc_np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### planning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# env.wall_hExts\n",
    "# plan_env_list.rand_rec_group.rec_loc_grp_list\n",
    "# type(diffusion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffuser.utils.rm2d_render import RandStaticMazeRenderer\n",
    "import imageio, os, time\n",
    "\n",
    "diffusion: GaussianDiffusionPB\n",
    "bs = 10 ## number of trajectories sampled in one batch\n",
    "horizon = 48 ## if the distance between start and goal is long, increase the horizon\n",
    "policy = Policy(diffusion, train_normalizer, use_ddim=True)\n",
    "diffusion.ddim_num_inference_steps = 8\n",
    "diffusion.condition_guidance_w = 2.0 ## guidance scale set to 2.0\n",
    "diffusion.horizon = horizon\n",
    "\n",
    "## the same start / goal pair is repeated for every sample of the batch\n",
    "start = np.array( [ [ 1.22, 4.74 ] ], dtype=np.float32 ).repeat( bs, 0 )\n",
    "goal = np.array( [ [3.86, 0.20] ], dtype=np.float32  ).repeat( bs, 0 )\n",
    "\n",
    "## flatten the obstacle rows into a single row and tile it over the batch\n",
    "wloc_tensor = utils.to_torch(wloc_np).reshape(1, -1).repeat( (bs, 1) )\n",
    "print( 'wloc_np: ', wloc_np.shape )\n",
    "\n",
    "## conditions are placed on the first and on the last step of the horizon\n",
    "cond = {0: start, \n",
    "        horizon-1: goal}\n",
    "\n",
    "## perf_counter is monotonic, so the measured duration cannot be skewed by clock adjustments\n",
    "tic = time.perf_counter()\n",
    "## samples: _, final trajectories, _, per_steps_state_trajs\n",
    "samples = policy(cond, batch_size=-1, wall_locations=wloc_tensor,use_normed_wallLoc=use_normed_wallLoc, return_diffusion=True)\n",
    "plan_time = time.perf_counter() - tic\n",
    "\n",
    "unnm_traj = samples[1].observations\n",
    "print(f'unnm_traj: {unnm_traj.shape}; plan time:{plan_time}')\n",
    "\n",
    "rd = RandStaticMazeRenderer(plan_env_list)\n",
    "rootdir = './visualization/'\n",
    "## make sure the output folder exists before anything is saved into it\n",
    "os.makedirs(rootdir, exist_ok=True)\n",
    "tmp_path = f'{rootdir}/concave_test.png'\n",
    "## the picture is read back from tmp_path below, so composite()'s return value is not kept\n",
    "rd.composite( tmp_path, unnm_traj,  np.array([env_id]))\n",
    "img = imageio.imread(tmp_path)\n",
    "utils.plt_img(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# len(samples[2]) # empty, ignore\n",
    "# samples[2]\n",
    "# env.maze_size # [5, 5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### create a denoising video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pb_diff_envs.environment.static.rand_maze2d_renderer import RandMaze2DRenderer\n",
    "\n",
    "traj_idx = 0 # which traj\n",
    "\n",
    "x0_perstep = samples[3] # shape: B, n_steps, Horizon, Dim\n",
    "## printed explicitly: a bare `x0_perstep.shape` in the middle of a cell is never displayed\n",
    "print(f'x0_perstep: {x0_perstep.shape}')\n",
    "\n",
    "n_steps = len(x0_perstep[0]) ## number of denoising steps + 1\n",
    "trajs_ps = [] ## per-step trajs\n",
    "imgs_ps = [] ## one rendered frame per entry of trajs_ps\n",
    "\n",
    "mz_renderer = RandMaze2DRenderer(num_walls_c=[30,10], fig_dpi=200)\n",
    "mz_renderer.up_right = 'empty'\n",
    "mz_renderer.config['no_draw_traj_line'] = False\n",
    "mz_renderer.config['no_draw_col_info'] = True # False\n",
    "# mz_renderer.num_walls_c = [0, 30]\n",
    "\n",
    "## the maze layout does not depend on the step index, so it is looked up once\n",
    "env_wlocs = env.wall_locations\n",
    "# hExt = plan_env_list.hExt_list[0, 0].tolist()\n",
    "# env_hExts = np.concatenate(env.hExt_list_c, axis=0)\n",
    "env_hExts = env.wall_hExts\n",
    "\n",
    "## index 0 is skipped -- presumably the initial noise sample; TODO confirm\n",
    "for i_st in range(1, n_steps):\n",
    "    traj_ps = x0_perstep[traj_idx][i_st]\n",
    "    trajs_ps.append(traj_ps)\n",
    "    img = mz_renderer.renders( traj_ps , plan_env_list.maze_size, env_wlocs, env_hExts )\n",
    "    imgs_ps.append(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## export the rendered frames as an mp4 and play it inline\n",
    "print('rootdir:', rootdir)\n",
    "# import diffuser.utils.video as duv; reload(duv)\n",
    "from diffuser.utils.video import save_images_to_mp4, read_mp4_to_numpy\n",
    "\n",
    "## the file goes under rootdir and is keyed by env_id\n",
    "mp4_name = f'{rootdir}/concave_{env_id}.mp4' # {utils.get_time()}\n",
    "\n",
    "## write the per-step frames to disk\n",
    "save_images_to_mp4(\n",
    "    imgs_ps,\n",
    "    mp4_name,\n",
    "    fps=15,\n",
    "    st_sec=1,\n",
    "    end_sec=1,\n",
    ")\n",
    "\n",
    "## read the file back and show it in the notebook\n",
    "import mediapy as media\n",
    "from mediapy import show_video, read_video\n",
    "video_frames = read_mp4_to_numpy(mp4_name)\n",
    "show_video(video_frames, fps=15, width=400)\n",
    "print(f'mp4_name: {mp4_name}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
